from openai import OpenAI
import torch
import argparse
from statistics import median
import os
import json
from tqdm import tqdm
import sys
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
from symeval import EvaluatorMathBatch

# Single evaluator instance created at import time; eval() below uses it for
# 'symeval' answer comparison.
symeval_evaluator = EvaluatorMathBatch()
# Echo the raw command line at import time (handy when reproducing a run from its log).
print("actual_args:", sys.argv)
# import tiktoken
from transformers import AutoTokenizer

# Candidate interruption phrases keyed by 1-based id. get_interrupt_signals()
# selects from this pool using a dash-separated id string such as "2-10".
signal_pool = dict(enumerate([
    'Alternately,',
    'Alternatively,',
    'Alternative:',
    'Wait, alternate approach:',
    'Wait, let me try another method',
    'Wait, let me try another approach',
    'Let me try another method',
    'Let me try another approach',
    'Wait, another approach:',
    'Wait, alternatively,',
    'Alternative approach:',
], start=1))

# ============================================================
#                          settings
# ============================================================
def _parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace holding every option declared below.

    Note:
        ``--only_wrong`` is parsed with an explicit string-to-bool converter.
        ``type=bool`` would turn any non-empty string, including "False",
        into ``True``, making the flag impossible to switch off.
    """
    def _str2bool(value):
        # Accept the usual textual spellings of a boolean option value.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='')
    #
    parser.add_argument('--model_path', type=str, help='', required=True)
    parser.add_argument('--port', type=int, help='', required=True)
    # data
    parser.add_argument('--data_path', type=str, help='', required=True)
    parser.add_argument('--start_id', type=int, default=0, help='')
    parser.add_argument('--end_id', type=int, default=-1, help='')
    parser.add_argument('--question_key', type=str, default='question')
    parser.add_argument('--question_id', type=str, default='id')
    parser.add_argument('--gt_answer_key', type=str, default='answer')
    parser.add_argument('--only_wrong', type=_str2bool, default=False)
    #
    parser.add_argument('--max_interrupts', type=int, default=1)
    parser.add_argument('--deepen_prompt_version', type=str, default='')
    parser.add_argument('--interrupt_signals_version', type=str, default='')
    #
    parser.add_argument('--system_prompt_version', type=str, default='')
    parser.add_argument('--question_prefix_version', type=str, default='')
    parser.add_argument('--question_suffix_version', type=str, default='')
    # sampling; None means "not set"
    parser.add_argument('--temperature', type=float, default=None)
    parser.add_argument('--top_p', type=float, default=None)
    parser.add_argument('--top_k', type=int, default=None)
    parser.add_argument('--repetition_penalty', type=float, default=None)
    parser.add_argument('--max_tokens', type=int, default=32768)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--extend_max_tokens', type=int, default=None)
    # divide_step_method
    parser.add_argument('--divide_step_method', type=str, default='v0', help='')
    parser.add_argument('--threshold', type=str, default='v0', help='')
    # prm
    parser.add_argument('--prm_model_path', type=str, default=None, help='')
    parser.add_argument('--prm_model_dtype', type=str, default=None, help='')
    parser.add_argument('--prm_gpu_memory_utilization', type=float, default=0.95, help='')
    parser.add_argument('--prm_max_model_len', type=int, default=None, help='')
    parser.add_argument('--prm_max_seq_len_to_capture', type=int, default=32768, help='')
    parser.add_argument('--prm_pipeline_parallel_size', type=int, default=1, help='')
    # save
    parser.add_argument('--save_root_dir', type=str, default='', required=True)

    args = parser.parse_args()
    return args
# ============================================================
#                           step_divide_setup
# ============================================================
def setup_step_divide():
    """Build the OpenAI client that serves as the step-division model.

    Both credentials below are placeholders and must be replaced before use.
    """
    client_kwargs = {
        'api_key': "xxxxxxxxxxxxxxxxxx",  # Replace with your actual API key
        'base_url': "xxxxxxxxxxxxxxxxxxxxxxxx",  # Replace with your actual base URL
    }
    return OpenAI(**client_kwargs)
# ============================================================
#                           prm_setup
# ============================================================
def setup_prm(args):
    """Load the process reward model (PRM) and its tokenizer.

    Args:
        args: parsed CLI namespace; reads the ``prm_*`` options plus
            ``extend_max_tokens`` and ``max_interrupts``.

    Returns:
        ``(model, tokenizer, sampling_params)``; ``sampling_params`` is always
        ``None`` for the PRM.
    """
    print('\n************************************')
    print('Setup prm')
    print('************************************')
    tokenizer = setup_tokenizer(model_path = args.prm_model_path)
    # --extend_max_tokens defaults to None; treat "unset" as "no extension"
    # instead of crashing on ``None * int``.
    extend_max_tokens = args.extend_max_tokens or 0
    max_seq_len_to_capture = (
        args.prm_max_seq_len_to_capture + extend_max_tokens * args.max_interrupts
    )
    model = setup_model(
        model_path = args.prm_model_path,
        model_dtype = args.prm_model_dtype,
        gpu_memory_utilization = args.prm_gpu_memory_utilization,
        max_model_len = args.prm_max_model_len,
        max_seq_len_to_capture = max_seq_len_to_capture,
        pipeline_parallel_size = args.prm_pipeline_parallel_size
    )
    sampling_params = None
    return model, tokenizer, sampling_params
def setup_model(
        model_path,
        model_dtype,
        gpu_memory_utilization,
        max_model_len,
        max_seq_len_to_capture,
        pipeline_parallel_size,

    ):
    """Create a vLLM ``LLM`` engine that runs the PRM with ``task="reward"``.

    Optional engine settings that are ``None`` are not passed, so vLLM's own
    defaults apply. The GPU count comes from ``CUDA_VISIBLE_DEVICES`` and falls
    back to ``torch.cuda.device_count()`` when the variable is unset. That
    count is split into ``pipeline_parallel_size`` pipeline stages of
    ``num_gpu // pipeline_parallel_size`` tensor-parallel GPUs.

    Raises:
        ValueError: if fewer GPUs are visible than ``pipeline_parallel_size``.
    """
    print('==> Setup prm model ...')
    kwargs = {}
    if model_dtype is not None:
        kwargs['dtype'] = model_dtype
    if gpu_memory_utilization is not None:
        kwargs['gpu_memory_utilization'] = gpu_memory_utilization
    if max_model_len is not None:
        kwargs['max_model_len'] = max_model_len
    if max_seq_len_to_capture is not None:
        kwargs['max_seq_len_to_capture'] = max_seq_len_to_capture
    print(f'Model kwargs: {kwargs}')

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        num_gpu = torch.cuda.device_count()
    else:
        # Ignore empty entries so that "" means zero GPUs rather than one.
        num_gpu = len([dev for dev in visible_devices.split(",") if dev.strip()])
    tensor_parallel_size = num_gpu // pipeline_parallel_size
    if tensor_parallel_size < 1:
        raise ValueError(
            f'pipeline_parallel_size={pipeline_parallel_size} needs at least that many GPUs, '
            f'but only {num_gpu} are visible'
        )
    print('tensor_parallel_size: ', tensor_parallel_size)
    print('pipeline_parallel_size: ', pipeline_parallel_size)
    if 'math-shepherd' in model_path:
        # Math-Shepherd-specific pooling: step-level scores read at a tag token.
        # The ids are tied to that model's tokenizer (presumably the step tag and
        # the positive/negative label tokens -- TODO confirm).
        kwargs['override_pooler_config'] = PoolerConfig(
            pooling_type="STEP",
            step_tag_id=12902,
            returned_token_ids=[648, 387]
        )
    model = LLM(
        model=model_path,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        trust_remote_code=True,
        task="reward",
        **kwargs,
    )
    return model
def setup_tokenizer(model_path):
    """Load the Hugging Face tokenizer stored at ``model_path`` (remote code allowed)."""
    print('==> Setup tokenizer ...')
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# ============================================================
#                           get_score
# ============================================================
def divide_steps(response,steps_divide_model,divide_step_method, max_attempts=1):
    """Split a reasoning trace into steps using the selected method.

    If ``response`` contains ``<think>``/``</think>``, only the text between
    the first ``<think>`` and the following ``</think>`` is used.

    Methods:
        'v2', 'v3', 'v4': call ``extract_steps_content_v2/3/4`` on the text
            directly (no API call).
        'v0', 'v1': ask ``steps_divide_model`` (an OpenAI-style client) to
            emit ``<step>`` XML, then map it back onto the text with
            ``extract_steps_content_v0`` / ``extract_steps_content_v1``.

    Args:
        response: the generated text to split.
        steps_divide_model: OpenAI-style client (only used by 'v0'/'v1').
        divide_step_method: one of 'v0'-'v4'.
        max_attempts: number of LLM calls tried for 'v0'/'v1' until the
            extractor succeeds (default 1, the original behaviour).

    Returns:
        Whatever the chosen extractor returns; ``False`` signals failure.

    Raises:
        NotImplementedError: for an unknown ``divide_step_method``. This is
            raised before any API call is made.
    """
    prompt ='''
    You are an expert in analyzing and decomposing complex problem-solving processes, especially in mathematics. \n\n---\n\n### Task:  \nYour task is to divide a long and systematic thinking process (provided below) into coherent, sequential steps. Each step should represent a complete phase of reasoning, such as problem analysis, exploration, reassessment, or verification. Ensure **no content is omitted** between steps, and the entire process is covered from start to finish.\n\n---\n\n### Output Format:  \nPresent the steps in the following structured XML-like format:  \n\n```XML\n<step number=\"step id\">  \n    <objective> Purpose of this step </objective>  \n    <start> First exact sentence of this step in the given thinking process </start>  \n    <end> Last exact sentence of this step in the given thinking process </end>  \n</step>  \n```\n\n---\n\n### Key Requirements:  \n1. **Continuity Preservation**:  \n   - The `end` sentence of step *i* must **immediately precede** the `start` sentence of step *i+1* in the original text.  \n   - No sentences should be skipped or omitted between steps.  \n\n2. **Complete Coverage**:  \n   - The last step’s `end` must be the **very last sentence** of the entire thinking process.  \n\n3. **Step Objectives**:  \n   - Label each step’s purpose clearly (e.g., \"Initial analysis,\" \"Error correction,\" \"Explore different ideas\").  \n   - For backtracking/reassessment, use objectives like \"Re-evaluating approach due to X.\"  \n\n---\n\n### Strict Validation Rules:  \n1. **Text Continuity Check**:  \n   - For all steps except the last, the `end` of step *i* must be the **direct predecessor** of the `start` of step *i+1* in the original text.  \n   - Example: If step 1 ends with *\"Now I’ll try Method A,\"* step 2 must start with the **very next sentence** in the original text (e.g., *\"First, I apply Method A to the equation...\"*).  \n\n2. 
**Final Step Coverage**:  \n   - The `end` of the final step **must match** the last sentence of the entire thinking process.  \n\n---\n\n### Instructions:  \n1. **Read the entire thinking process carefully**: Identify logical segments where the problem-solver shifts focus (e.g., from analyzing to solving or reflecting, or exploring, or summarizing).  \n2. **Define each step**: Assign a unique step number and describe its purpose (objective). \n3. **Adjust step granularity adaptively**: Smaller steps for detailed reasoning, larger steps for broader phases.  \n4. **Extract the text**: Mark the exact beginning and ending sentences of each step in the original text.  \n5. **Ensure every sentence is included** in exactly one step, with no overlaps or gaps.  \n6. **Explicitly verify** the key requirements above before finalizing the output.  \n\n---\n\n### Thinking Process to Decompose (Input):
    '''
    if '<think>' in response:
        response = response.split('<think>')[1]
    if '</think>' in response:
        response = response.split('</think>')[0]
    pred_out = response

    # Local splitters: no LLM call needed.
    if divide_step_method == 'v2':
        print('use method v2')
        return extract_steps_content_v2(pred_out)
    elif divide_step_method == 'v3':
        print('use method v3')
        return extract_steps_content_v3(pred_out)
    elif divide_step_method == 'v4':
        print('use method v4')
        return extract_steps_content_v4(pred_out)

    # Fail fast so an unsupported method never costs an API call.
    if divide_step_method not in ('v0', 'v1'):
        raise NotImplementedError(f'unknown divide_step_method: {divide_step_method!r}')

    extract_steps = False
    for _ in range(max_attempts):
        completion = steps_divide_model.chat.completions.create(
            model="xxxx", # Replace with your actual model name
            messages=[
                {"role": "user", "content": f"{prompt}\n{pred_out}"}
            ]
        )
        steps_text = completion.choices[0].message.content
        if divide_step_method == 'v0':
            extract_steps = extract_steps_content_v0(pred_out, steps_text)
        else:
            extract_steps = extract_steps_content_v1(pred_out, steps_text)
        if extract_steps != False:
            break
    return extract_steps

def check_score(messages,output_inherited_by_next_round,prm_model,prm_tokenizer,stpes_divide_model,question,divide_step_method,cached_response="", cached_steps=None, cached_scores=None):
    if messages[-1]['role'] == 'user':
        response = output_inherited_by_next_round
    elif messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'user':
        response =  messages[-1]['content'] + output_inherited_by_next_round
    else:
        raise ValueError
    can_reuse = (
        cached_response is not None and len(cached_response) > 0 and
        cached_steps is not None and len(cached_steps) > 0 and
        cached_scores is not None and len(cached_scores) > 0 and
        response.startswith(cached_response)
    )
    if can_reuse:
        print("detected reusable cached content")
        # deal with the new content
        new_content = response[len(cached_response):]
        if not new_content.strip():
            print("No new content, returning cached scores directly")
            return cached_scores, cached_steps

        # divide the new content into steps
        new_steps = divide_steps(new_content, stpes_divide_model, divide_step_method)
        if new_steps == False:
            print("cannot divide new content into steps, returning cached scores and steps")
            can_reuse = False
        else:
            input_steps, input_messages, steps_end_pos = build_input(new_steps, question)
            combined_steps = cached_steps + input_steps
            query_id = prm_tokenizer.apply_chat_template(
                input_messages,
                tokenize=True,
                add_generation_prompt=True
            )
            outputs_score = cached_scores.copy()
            print(f"reusing cached scores for {len(cached_scores)} steps, adding scores for {len(input_steps)} new steps")
            for step_idx in range(len(cached_steps) + 1, len(combined_steps) + 1):
                # here calculate the score for the new step
                responses = "\n\n".join(input_steps[:step_idx]) + "\n\n"
                answer_tokens = prm_tokenizer(responses)['input_ids']
                answer_tokens += [prm_tokenizer.eos_token_id]
                QA_ids = query_id + answer_tokens
                token_ids_list = QA_ids.tolist() if isinstance(QA_ids, torch.Tensor) else QA_ids
                token_list_len = len(token_ids_list)
                max_seq_len = prm_model.llm_engine.model_config.max_seq_len_to_capture
                i = 1
                while token_list_len > max_seq_len:
                    responses = "\n\n".join(input_steps[i:step_idx]) + "\n\n"
                    answer_tokens = prm_tokenizer(responses)['input_ids']
                    answer_tokens += [prm_tokenizer.eos_token_id]
                    QA_ids = query_id + answer_tokens
                    token_ids_list = QA_ids.tolist() if isinstance(QA_ids, torch.Tensor) else QA_ids
                    token_list_len = len(token_ids_list)
                    i+=1
                assert len(token_ids_list) <= max_seq_len
                (output,) = prm_model.encode(prompt_token_ids= token_ids_list)
                score = torch.sigmoid(output.outputs.data[-1]).cpu().item()
                outputs_score.append(score)

            print(f'reuse {len(cached_scores)} scores, add {len(outputs_score) - len(cached_scores)} new scores')
            assert len(outputs_score) == len(combined_steps)
            return outputs_score, combined_steps




    steps = divide_steps(response,stpes_divide_model,divide_step_method)
    if steps == False:
        print("Error: Failed to extract steps.")
        return False,False
    input_steps, input_messages,steps_end_pos = build_input(steps,question)
    query_id = prm_tokenizer.apply_chat_template(
        input_messages,
        tokenize=True,
        add_generation_prompt=True
    )
    print(input_messages)
    outputs_score = []
    for step_idx in range(1, len(input_steps)+1):
        responses = "\n\n".join(input_steps[:step_idx]) + "\n\n"
        answer_tokens = prm_tokenizer(responses)['input_ids']
        answer_tokens += [prm_tokenizer.eos_token_id]
        QA_ids = query_id + answer_tokens
        token_ids_list = QA_ids.tolist() if isinstance(QA_ids, torch.Tensor) else QA_ids
        token_list_len = len(token_ids_list)
        max_seq_len = prm_model.llm_engine.model_config.max_seq_len_to_capture
        i = 1
        while token_list_len > max_seq_len:
            responses = "\n\n".join(input_steps[i:step_idx]) + "\n\n"
            answer_tokens = prm_tokenizer(responses)['input_ids']
            answer_tokens += [prm_tokenizer.eos_token_id]
            QA_ids = query_id + answer_tokens
            token_ids_list = QA_ids.tolist() if isinstance(QA_ids, torch.Tensor) else QA_ids
            token_list_len = len(token_ids_list)
            i+=1
        assert len(token_ids_list) <= max_seq_len
        (output,) = prm_model.encode(prompt_token_ids= token_ids_list)
        score = torch.sigmoid(output.outputs.data[-1]).cpu().item()
        # print(f'Response: {responses}, Score: {score}')
        # print(f"Data: {score!r}")
        outputs_score.append(score)




    assert len(outputs_score) == len(steps)
    print(f'outputs_score: {outputs_score}')
    return outputs_score,input_steps

# ============================================================
#                            eval
# ============================================================

def eval(eval_mode, gt, pred):
    """Judge whether ``pred`` matches the ground truth ``gt``.

    Only the 'symeval' mode is supported. ``gt`` is stringified and compared
    with the module-level ``symeval_evaluator`` as a batch of one.

    Raises:
        NotImplementedError: for any other ``eval_mode``.
    """
    if eval_mode != 'symeval':
        raise NotImplementedError
    verdicts = symeval_evaluator.batch_eq(ref_answers=[str(gt)], pred_answers=[pred])
    return verdicts[0]



# ============================================================
#                            parse
# ============================================================

def parse_answer_boxed(pred_str):
    """Return the content of the LAST ``boxed`` answer in ``pred_str``.

    - ``...\\boxed{X}...`` gives ``X``. Nested braces are balanced; if the
      closing brace is missing, everything after the opening brace is returned.
    - ``...\\boxed X$...`` (no brace) gives the text up to the next ``$``,
      stripped.
    - Returns ``""`` if there is no ``boxed`` marker or nothing follows the
      last one.
    """
    parts = pred_str.split("boxed")
    if len(parts) == 1:
        # no "boxed" marker at all
        return ""
    tail = parts[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        return tail.split("$")[0].strip()
    depth = 1
    chars = []
    for c in tail[1:]:
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                break
        chars.append(c)
    return "".join(chars)


def parse_pred_answer(parse_mode, pred_str):
    """Extract the predicted answer from ``pred_str`` using ``parse_mode``.

    Only 'parse_boxed' is supported; it delegates to ``parse_answer_boxed``.

    Raises:
        NotImplementedError: for any other ``parse_mode``.
    """
    if parse_mode != 'parse_boxed':
        raise NotImplementedError
    return parse_answer_boxed(pred_str)

# ============================================================
#                           utils
# ============================================================
def filter_exist(dataset, results,question_id):
    """Return the items of ``dataset`` whose ``question_id`` value does not occur in ``results``.

    Dataset order is preserved. Ids are looked up in a set, falling back to
    a list when an id is unhashable.
    """
    done_ids = [data[question_id] for data in results]
    try:
        done_ids = set(done_ids)
    except TypeError:
        # unhashable ids (e.g. lists): keep the list and use linear membership tests
        pass
    return [data for data in dataset if data[question_id] not in done_ids]
def filter_right(dataset,gt_answer_key):
    """Keep only the items whose saved generation is not judged correct.

    The boxed answer is parsed from each item's
    ``generation_info['output']`` and compared with ``item[gt_answer_key]``
    in 'symeval' mode. Items whose verdict is not equal to ``True`` are kept.
    """
    kept = []
    for item in dataset:
        predicted = parse_pred_answer('parse_boxed', item['generation_info']['output'])
        reference = item[gt_answer_key]
        verdict = eval('symeval', reference, predicted)
        if verdict != True:
            kept.append(item)
    return kept
def filter_round(dataset, max_round=6):
    """Keep items whose ``generation_info['num_round']`` is at most ``max_round`` (default 6)."""
    return [
        data for data in dataset
        if data['generation_info']['num_round'] <= max_round
    ]




def load_jsonl(path):
    """Read a JSON Lines file into a list with one object per non-blank line.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding. Blank lines, such as a trailing empty line, are skipped instead
    of raising ``JSONDecodeError``.
    """
    dataset = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            dataset.append(json.loads(line))
    return dataset

def load_json(path):
    """Read a single JSON document from ``path``, decoded as UTF-8."""
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def save_jsonl(x, save_path):
    """Write the objects in ``x`` to ``save_path`` as UTF-8 JSON Lines.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is always encoded as UTF-8 regardless of the platform locale. Missing
    parent directories are created. A bare filename is written to the current
    directory.
    """
    parent = os.path.dirname(save_path)
    if parent:  # os.makedirs('') would raise FileNotFoundError
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as file:
        for obj in x:
            json.dump(obj, file, ensure_ascii=False)
            file.write('\n')

def save_json(x, save_path):
    """Write ``x`` to ``save_path`` as indented UTF-8 JSON and log the path.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is always encoded as UTF-8. Missing parent directories are created.
    A bare filename is written to the current directory.
    """
    parent = os.path.dirname(save_path)
    if parent:  # os.makedirs('') would raise FileNotFoundError
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(x, f, indent=4, ensure_ascii=False)
    print('saved to: ', save_path)

def get_save_path(args):
    """Build the JSONL and JSON output paths and create their directories.

    The file stem is ``output`` when the whole dataset is processed
    (``start_id == 0`` and ``end_id == -1``), otherwise ``{start_id}_{end_id}``.
    Files go under ``<save_root_dir>/jsonl/`` and ``<save_root_dir>/json/``.

    Returns:
        ``(save_path_jsonl, save_path_json)``.
    """
    whole_dataset = args.start_id == 0 and args.end_id == -1
    stem = 'output' if whole_dataset else f'{args.start_id}_{args.end_id}'
    paths = {
        ext: os.path.join(args.save_root_dir, ext, f'{stem}.{ext}')
        for ext in ('jsonl', 'json')
    }
    print(f"save_path_jsonl: {paths['jsonl']}")
    print(f"save_path_json: {paths['json']}")
    for path in paths.values():
        os.makedirs(os.path.dirname(path), exist_ok=True)
    return paths['jsonl'], paths['json']

def load_data(data_path, start_id, end_id):
    """Load a JSONL dataset and optionally keep only the slice ``[start_id:end_id]``.

    ``end_id == -1`` means "to the end". ``start_id == 0`` together with
    ``end_id == -1`` keeps every record.
    """
    records = load_jsonl(data_path)
    print('num_data (loaded): ', len(records))
    if (start_id, end_id) == (0, -1):
        print('==> No subset, use all data ...')
        return records
    print('==> Select a subset ...')
    stop = None if end_id == -1 else end_id
    records = records[start_id:stop]
    print('num_data (subset): ', len(records))
    return records

def choice_dict_to_str(x):
    """Render a ``{label: text}`` mapping as one ``"LABEL: text"`` line per entry.

    Labels are upper-cased and texts are trimmed of spaces and newlines. The
    result has no leading or trailing spaces or newlines.
    """
    trim_chars = ' \n'
    lines = []
    for label, text in x.items():
        lines.append(label.upper() + ': ' + text.strip(trim_chars))
    return '\n'.join(lines).strip(trim_chars)

def choice_list_to_dict(lst):
    """Label choices ``A``, ``B``, ``C``, ... in list order, e.g. ``['x', 'y'] -> {'A': 'x', 'B': 'y'}``."""
    labelled = {}
    for offset, choice in enumerate(lst):
        labelled[chr(ord('A') + offset)] = choice
    return labelled

def get_question_prefix(mode):
    """Return the instruction placed before the question for ``mode`` ('v0'-'v3').

    'v0' is empty. 'v1'-'v3' differ only in the separator (space, newline,
    blank line) that follows the instruction.

    Raises:
        NotImplementedError: for any other mode.
    """
    prefixes = {
        'v0': "",
        'v1': "Please reason step by step, and put your final answer within \\boxed{}. ",
        'v2': "Please reason step by step, and put your final answer within \\boxed{}.\n",
        'v3': "Please reason step by step, and put your final answer within \\boxed{}.\n\n",
    }
    if mode not in prefixes:
        raise NotImplementedError
    prefix = prefixes[mode]
    print(f'question prefix ({mode}): {prefix}')
    return prefix

def get_question_suffix(mode):
    """Return the instruction appended after the question for ``mode`` ('v0'-'v3').

    'v0' is empty. 'v1'-'v3' differ only in the separator (space, newline,
    blank line) that precedes the instruction.

    Raises:
        NotImplementedError: for any other mode.
    """
    suffixes = {
        'v0': "",
        'v1': " Please reason step by step, and put your final answer within \\boxed{}.",
        'v2': "\nPlease reason step by step, and put your final answer within \\boxed{}.",
        'v3': "\n\nPlease reason step by step, and put your final answer within \\boxed{}.",
    }
    if mode not in suffixes:
        raise NotImplementedError
    suffix = suffixes[mode]
    print(f'question suffix ({mode}): {suffix}')
    return suffix

def get_system_prompt(mode):
    """Return the system prompt for ``mode``. Only 'v0' (an empty prompt) exists.

    Raises:
        NotImplementedError: for any other mode.
    """
    if mode != 'v0':
        raise NotImplementedError
    prompt = ""
    print(f'system prompt ({mode}): {prompt}')
    return prompt

def get_deepen_prompt(mode):
    """Return the "dive deeper" nudge text for ``mode`` ('v1'-'v3').

    Raises:
        NotImplementedError: for any other mode.
    """
    prompts = {
        'v1': " Wait, this is a promising idea. Let's dive deeper into this idea and do not give up easily. ",
        'v2': " Wait, this is a promising idea. But as we dive deeper into it, let's also critically reconsider the earlier points to ensure they are valid.",
        'v3': "Wait, let's dive deeper into this idea.",
    }
    if mode not in prompts:
        raise NotImplementedError
    prompt = prompts[mode]
    print(f'deepen prompt ({mode}): {prompt}')
    return prompt

def get_interrupt_signals(mode):
    """Translate a dash-separated id string (e.g. ``"2-10"``) into phrases from ``signal_pool``.

    Every id is parsed to ``int`` before any lookup. A non-numeric id raises
    ``ValueError``; an id missing from the pool raises ``KeyError``.
    """
    ids = list(map(int, mode.split('-')))
    signals = [signal_pool[signal_id] for signal_id in ids]
    print(f'interrupt signals ({mode}): {signals}')
    return signals

def get_sampling_params(args):
    """Collect the sampling options set on ``args`` into a dict.

    Options whose value is ``None`` are omitted, presumably so that the
    backend's defaults apply to them. See
    https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-params
    """
    option_names = (
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty',
        'max_tokens',
        'seed',
    )
    sampling_params = {}
    for name in option_names:
        value = getattr(args, name)
        if value is not None:
            sampling_params[name] = value

    print('sampling_params: ', sampling_params)
    return sampling_params

def flexible_find_v0(text, sentence, start_pos=0):
    """Find ``sentence`` in ``text`` at or after ``start_pos``, tolerating small mismatches.

    Strategies are tried in order and the first hit wins:
      1. exact match;
      2. case-insensitive match;
      3. match after collapsing every whitespace run (in both strings) to one space;
      4. like 3, but also case-insensitive;
      5. exact match of the first 60 characters (only when ``sentence`` is longer).

    Every strategy honours ``start_pos``.

    Returns:
        Index in the ORIGINAL ``text`` where the match starts, or -1.
    """
    import bisect

    # 1. exact match
    idx = text.find(sentence, start_pos)
    if idx != -1:
        return idx

    # 2. ignore capitalization
    idx = text.lower().find(sentence.lower(), start_pos)
    if idx != -1:
        print(f"Found match by ignoring case")
        return idx

    # Whitespace-collapsed copy of ``text`` (equal to ' '.join(text.split())) together
    # with, for every kept character, its index in ``text``. The explicit mapping keeps
    # returned indices exact even when the text contains runs of whitespace.
    clean_chars = []
    clean_to_orig = []
    for pos, ch in enumerate(text):
        if ch.isspace():
            continue
        if clean_chars and text[pos - 1].isspace():
            clean_chars.append(' ')
            clean_to_orig.append(pos - 1)
        clean_chars.append(ch)
        clean_to_orig.append(pos)
    clean_text = ''.join(clean_chars)
    clean_sentence = ' '.join(sentence.split())
    # first collapsed position whose original index is >= start_pos
    clean_start = bisect.bisect_left(clean_to_orig, start_pos)

    def _to_orig(clean_idx):
        return clean_to_orig[clean_idx] if clean_idx < len(clean_to_orig) else len(text)

    # 3. ignore extra spaces
    idx = clean_text.find(clean_sentence, clean_start)
    if idx != -1:
        print(f"by ignoring extra spaces found match")
        return _to_orig(idx)

    # 4. ignore capitalization and extra spaces.
    # Lower-case per character so the string length (and thus the index mapping)
    # is preserved; str.lower() can expand some characters.
    def _lower_keep_len(s):
        return ''.join(c.lower() if len(c.lower()) == 1 else c for c in s)

    idx = _lower_keep_len(clean_text).find(_lower_keep_len(clean_sentence), clean_start)
    if idx != -1:
        print(f"by ignoring capitalization and extra spaces found match")
        return _to_orig(idx)

    # 5. partial match (first 60 characters)
    if len(sentence) > 60:
        partial_sentence = sentence[:60]
        idx = text.find(partial_sentence, start_pos)
        if idx != -1:
            print(f"by partial match found match")
            return idx

    return -1
def extract_steps_content_v0(pred_output, steps_text):
    """Map the LLM's ``<step>`` XML back onto the original thinking text.

    Args:
        pred_output: the reasoning text. If it contains both ``<think>`` and
            ``</think>``, only the part between them is used.
        steps_text: the LLM reply with
            ``<step number=".."><objective>..</objective><start>..</start><end>..</end>...</step>``
            blocks, optionally containing ``<substep>`` blocks.

    Returns:
        A list of dicts with ``step_number``, ``objective``, ``content``,
        ``start_pos`` and ``end_pos``. A step also gets ``substeps`` when at
        least one substep could be located. Returns ``False`` if the start or
        end sentence of any step cannot be found. Text between steps that no
        step covers is reported with a warning but not included in any step.
    """
    import re

    if '<think>' in pred_output and '</think>' in pred_output:
        # Text after the first <think> (minus one newline right after the tag) up to
        # the next </think>; tolerates a tag that is not followed by a newline.
        thinking_text = pred_output.split('<think>', 1)[1]
        if thinking_text.startswith('\n'):
            thinking_text = thinking_text[1:]
        thinking_text = thinking_text.split('</think>')[0]
    else:
        thinking_text = pred_output

    step_pattern = r'<step number="([^"]+)">\s*<objective>([^<]+)</objective>\s*<start>([^<]+)</start>\s*<end>([^<]+)</end>(.*?)</step>'
    substep_pattern = r'<substep (?:number|id)="([^"]+)">\s*<objective>([^<]+)</objective>\s*<start>([^<]+)</start>\s*<end>([^<]+)</end>\s*</substep>'

    initial_pos = 0
    steps_content = []

    for step_num, objective, start_sentence, end_sentence, inner_content in re.findall(
            step_pattern, steps_text, re.DOTALL):
        # The requested format pads tag contents ("<start> ... </start>"). Strip the
        # padding so the search targets the real sentence and end_pos does not overshoot.
        objective = objective.strip()
        start_sentence = start_sentence.strip()
        end_sentence = end_sentence.strip()

        start_idx = flexible_find_v0(thinking_text, start_sentence, initial_pos)
        if start_idx == -1:
            print("Warning: Cannot find the start sentence of step", step_num, "in the original text")
            return False

        if initial_pos != start_idx:
            miss_content = thinking_text[initial_pos:start_idx]
            if len(miss_content.strip()) > 0:
                print(f"Warning: There is uncovered content before step {step_num}: {miss_content[:50]}...")

        end_idx = flexible_find_v0(thinking_text, end_sentence, start_idx)
        if end_idx == -1:
            print("Warning: Cannot find the end sentence of step", step_num, "in the original text")
            return False

        end_pos = end_idx + len(end_sentence)
        step_content = thinking_text[start_idx:end_pos]

        step_obj = {
            'step_number': step_num,
            'objective': objective,
            'content': step_content,
            'start_pos': start_idx,
            'end_pos': end_pos,
        }

        # Substep positions are relative to step_content, not to thinking_text.
        substeps = []
        substep_matches = re.findall(substep_pattern, inner_content) if inner_content else []
        for substep_id, substep_obj, substep_start, substep_end in substep_matches:
            substep_start = substep_start.strip()
            substep_end = substep_end.strip()

            substep_start_idx = flexible_find_v0(step_content, substep_start)
            if substep_start_idx == -1:
                print(f"Warning: Cannot find the start sentence of substep {substep_id} in the step content")
                continue

            substep_end_idx = flexible_find_v0(step_content, substep_end, substep_start_idx)
            if substep_end_idx == -1:
                print(f"Warning: Cannot find the end sentence of substep {substep_id} in the step content")
                continue

            substep_end_pos = substep_end_idx + len(substep_end)
            substeps.append({
                'step_number': substep_id,
                'objective': substep_obj.strip(),
                'content': step_content[substep_start_idx:substep_end_pos],
                'start_pos': substep_start_idx,
                'end_pos': substep_end_pos,
                'is_substep': True
            })

        if substeps:
            step_obj['substeps'] = substeps

        initial_pos = end_pos
        steps_content.append(step_obj)

    return steps_content
def flexible_find_v1(text, sentence, start_pos=0):
    """Find ``sentence`` in ``text`` at or after ``start_pos``, tolerating small mismatches.

    Strategies are tried in order and the first hit wins:
      1. exact match;
      2. case-insensitive match;
      3. match after collapsing every whitespace run (in both strings) to one space;
      4. like 3, but also case-insensitive;
      5. exact match of the first 60 characters (only when ``sentence`` is longer).

    ``start_pos`` is an index into the ORIGINAL text. For strategies 3 and 4
    it is translated into whitespace-collapsed coordinates before searching.

    Returns:
        Index in the ORIGINAL ``text`` where the match starts, or -1.
    """
    import bisect

    # 1. Exact match
    idx = text.find(sentence, start_pos)
    if idx != -1:
        return idx

    # 2. Ignore case
    idx = text.lower().find(sentence.lower(), start_pos)
    if idx != -1:
        print(f"Found match by ignoring case")
        return idx

    # Whitespace-collapsed copy of ``text`` (equal to ' '.join(text.split())) plus the
    # original index of every kept character, so positions map back exactly.
    clean_chars = []
    clean_to_orig = []
    for pos, ch in enumerate(text):
        if ch.isspace():
            continue
        if clean_chars and text[pos - 1].isspace():
            clean_chars.append(' ')
            clean_to_orig.append(pos - 1)
        clean_chars.append(ch)
        clean_to_orig.append(pos)
    clean_text = ''.join(clean_chars)
    clean_sentence = ' '.join(sentence.split())
    # start_pos expressed in collapsed coordinates
    clean_start = bisect.bisect_left(clean_to_orig, start_pos)

    def _to_orig(clean_idx):
        return clean_to_orig[clean_idx] if clean_idx < len(clean_to_orig) else len(text)

    # 3. Ignore extra spaces
    idx = clean_text.find(clean_sentence, clean_start)
    if idx != -1:
        print(f"Found match by ignoring extra spaces")
        return _to_orig(idx)

    # 4. Ignore both case and extra spaces.
    # Per-character lower-casing keeps the length unchanged so the mapping stays valid.
    def _lower_keep_len(s):
        return ''.join(c.lower() if len(c.lower()) == 1 else c for c in s)

    idx = _lower_keep_len(clean_text).find(_lower_keep_len(clean_sentence), clean_start)
    if idx != -1:
        print(f"Found match by ignoring both case and extra spaces")
        return _to_orig(idx)

    # 5. Try partial match (take the first 60 characters)
    if len(sentence) > 60:
        partial_sentence = sentence[:60]
        idx = text.find(partial_sentence, start_pos)
        if idx != -1:
            print(f"Found result through partial match")
            return idx

    return -1

def fill_missing_intervals(starts, ends, total_length, original_need_process):
    """Fill the gaps between sorted intervals so they tile ``[0, total_length)``.

    Interval ``i`` is ``[starts[i], ends[i])``. Gap intervals are inserted
    before the first interval, between consecutive intervals with
    ``ends[i] < starts[i+1]``, and after the last interval. Every inserted gap
    is marked as needing processing. An original interval is marked if its
    index appears in ``original_need_process``.

    Args:
        starts: Interval start positions, in order.
        ends: Interval end positions (same length as ``starts``).
        total_length: Length of the text being covered.
        original_need_process: Indices into ``starts``/``ends`` that need
            processing. Duplicates are allowed.

    Returns:
        ``(filled_starts, filled_ends, new_need_process)``. ``new_need_process``
        holds indices into the filled lists, in ascending order.
    """
    filled_starts = []
    filled_ends = []
    new_need_process = []
    # Set for O(1) membership checks.
    need_process_set = set(original_need_process)

    def _append(start, end, needs_processing):
        # Add one interval; its index is its position in the filled lists.
        filled_starts.append(start)
        filled_ends.append(end)
        if needs_processing:
            new_need_process.append(len(filled_starts) - 1)

    # 1. Leading gap [0, first start). With no intervals it spans the whole
    # text. It is skipped when empty (no intervals and total_length == 0) to
    # avoid emitting a zero-length interval.
    first_real_start = starts[0] if starts else total_length
    if first_real_start != 0:
        _append(0, first_real_start, True)

    # 2. Original intervals, each followed by the gap up to the next one (if any).
    last = len(starts) - 1
    for i, (start, end) in enumerate(zip(starts, ends)):
        _append(start, end, i in need_process_set)
        if i < last and end < starts[i + 1]:
            _append(end, starts[i + 1], True)

    # 3. Trailing gap [last end, total_length).
    if ends and ends[-1] < total_length:
        _append(ends[-1], total_length, True)

    return filled_starts, filled_ends, new_need_process
def extract_steps_content_v1(pred_output, steps_text):
    """
    Extract the full content of each step from the original thinking process.

    Each ``<step number="..."><objective>...</objective><start>S</start>
    <end>E</end>...</step>`` element in ``steps_text`` is located in the
    thinking text by searching (with ``flexible_find_v1``) for its start
    sentence S and then its end sentence E, scanning left to right.

    * If every step was located, each step spans from its start sentence
      through the end of its end sentence.
    * Otherwise, the located steps are repaired and gaps between steps are
      filled:
      - neighbouring steps with a missing end / missing next start are merged;
      - remaining missing boundaries are borrowed from neighbours;
      - gaps are filled with ``fill_missing_intervals``.
      Every repaired or filler interval longer than 1000 characters (after
      stripping) is split further. It is split at ``signal_pool`` phrases if
      any occur (the phrases themselves are dropped). Otherwise it is split
      into groups of sentences of roughly 500 characters.

    Args:
    - pred_output: The original thinking process text. If it contains both
      '<think>' and '</think>', only the text between them is used.
    - steps_text: The text containing the step structure (XML-like format).

    Returns:
    - A list of dicts with keys 'content', 'start_pos' and 'end_pos'
      (positions index into the thinking text). Pieces produced by further
      splitting also carry 'split': True.
    """
    import re

    # Get the thinking content. Tolerate '<think>' that is not followed by a
    # newline; splitting on '<think>\n' raised IndexError in that case.
    if '<think>' in pred_output and '</think>' in pred_output:
        after_open = pred_output.split('<think>', 1)[1]
        if after_open.startswith('\n'):
            after_open = after_open[1:]
        thinking_text = after_open.split('</think>')[0]
    else:
        thinking_text = pred_output

    step_pattern = r'<step number="([^"]+)">\s*<objective>([^<]+)</objective>\s*<start>([^<]+)</start>\s*<end>([^<]+)</end>(.*?)</step>'
    matches = re.findall(step_pattern, steps_text, re.DOTALL)

    steps_content = []
    start_positions = []
    end_positions = []  # (end_idx, end_sentence) pairs; end_idx is where the end sentence begins
    initial_pos = 0
    need_split = False

    # 1. Locate each step's start and end sentence.
    for step_num, _objective, start_sentence, end_sentence, _inner_content in matches:
        start_idx = flexible_find_v1(thinking_text, start_sentence, initial_pos)
        end_idx = flexible_find_v1(
            thinking_text,
            end_sentence,
            start_idx + len(start_sentence) if start_idx != -1 else initial_pos
        )
        if start_idx == -1 or end_idx == -1:
            need_split = True
        if start_idx == -1 and end_idx == -1:
            print(f"Error: Step {step_num} not found in the original text.")
            continue  # Skip invalid steps

        start_positions.append(start_idx)
        end_positions.append((end_idx, end_sentence))

        # The next step is searched for after this one.
        if end_idx != -1:
            initial_pos = end_idx + len(end_sentence)
        elif start_idx != -1:
            initial_pos = start_idx + len(start_sentence)

    # Fast path: every boundary was found.
    if not need_split:
        print('No steps need to be split, all steps were found')
        for start_idx, (end_idx, end_sentence) in zip(start_positions, end_positions):
            end_pos = end_idx + len(end_sentence)
            steps_content.append({
                'content': thinking_text[start_idx:end_pos].strip(),
                'start_pos': start_idx,
                'end_pos': end_pos
            })
        return steps_content

    # 2. Preprocessing: merge step i with step i+1 when end[i] == -1 and
    # start[i+1] == -1 (and end[i+1] is known).
    i = 0
    while i < len(end_positions) - 1:
        current_end, _ = end_positions[i]
        next_start = start_positions[i + 1]
        next_end, next_end_sentence = end_positions[i + 1]
        if current_end == -1 and next_start == -1 and next_end != -1:
            end_positions[i] = (next_end, next_end_sentence)
            del start_positions[i + 1]
            del end_positions[i + 1]
            # Do not advance: the merged step may need merging again.
        else:
            i += 1

    # 3. Resolve remaining -1 boundaries. After step 2 there is no pair with
    # end[i] == -1 and start[i+1] == -1.
    text_length = len(thinking_text)
    new_start_positions = []
    new_end_positions = []
    need_process = []
    for i, (start_idx, (end_idx, end_sentence)) in enumerate(zip(start_positions, end_positions)):
        current_start = start_idx
        if end_idx != -1:
            # Include the end sentence itself, as the fast path above does.
            # Otherwise each end sentence would be left in a gap and turned
            # into its own tiny step by fill_missing_intervals.
            current_end = min(end_idx + len(end_sentence), text_length)
        else:
            need_process.append(i)
            if i == len(end_positions) - 1:
                current_end = text_length
            else:
                current_end = start_positions[i + 1]  # not -1 thanks to step 2

        if current_start == -1:
            need_process.append(i)
            if i == 0:
                current_start = 0
            else:
                prev_end, prev_end_sentence = end_positions[i - 1]
                if prev_end != -1:
                    current_start = min(prev_end + len(prev_end_sentence), text_length)
                else:
                    current_start = 0

        new_start_positions.append(current_start)
        new_end_positions.append(current_end)

    # 4. Fill gaps and check the results.
    print(f"Original start_positions: {new_start_positions}")
    print(f"Original end_positions: {new_end_positions}")
    print(f"Original need_process: {need_process}")
    new_start_positions, new_end_positions, need_process = fill_missing_intervals(
        new_start_positions, new_end_positions, text_length, need_process)
    print(f"Filled start_positions: {new_start_positions}")
    print(f"Filled end_positions: {new_end_positions}")
    print(f'Filled need_process: {need_process}')
    assert -1 not in new_start_positions and -1 not in new_end_positions
    for i, (start_idx, end_idx) in enumerate(zip(new_start_positions, new_end_positions)):
        assert start_idx != -1 and end_idx != -1
        assert start_idx < end_idx, f"Step {i}'s start_pos >= end_pos: {start_idx} >= {end_idx}"
        if i < len(new_start_positions) - 1:
            assert end_idx <= new_start_positions[i+1], f"Step {i}'s end_pos >= next step's start_pos: {end_idx} >= {new_start_positions[i+1]}"

    # 5. Build the steps, splitting long repaired/filler intervals further.
    need_process_set = set(need_process)
    for i, (start_idx, end_idx) in enumerate(zip(new_start_positions, new_end_positions)):
        raw_content = thinking_text[start_idx:end_idx]
        step_content = raw_content.strip()
        if i not in need_process_set or len(step_content) <= 1000:
            steps_content.append({
                'content': step_content,
                'start_pos': start_idx,
                'end_pos': end_idx
            })
            continue

        # Offsets below are relative to the *stripped* content, so anchor them
        # at the first non-whitespace character of the interval.
        base_pos = start_idx + (len(raw_content) - len(raw_content.lstrip()))

        # 5a. Prefer splitting at signal phrases.
        signal_positions = []
        for signal in signal_pool.values():
            pos = step_content.find(signal)
            while pos != -1:
                signal_positions.append((pos, signal))
                pos = step_content.find(signal, pos + 1)

        if signal_positions:
            signal_positions.sort(key=lambda x: x[0])
            last_pos = 0
            for pos, signal in signal_positions:
                if pos > last_pos:  # There is a text segment before this signal
                    text = step_content[last_pos:pos].strip()
                    if text:
                        steps_content.append({
                            'content': text,
                            'start_pos': base_pos + last_pos,
                            'end_pos': base_pos + pos,
                            'split': True  # Mark as split
                        })
                # Skip the signal. max() keeps the cursor from moving backwards
                # if two signal matches overlap.
                last_pos = max(last_pos, pos + len(signal))

            # The segment after the last signal
            if last_pos < len(step_content):
                text = step_content[last_pos:].strip()
                if text:
                    steps_content.append({
                        'content': text,
                        'start_pos': base_pos + last_pos,
                        'end_pos': base_pos + len(step_content),
                        'split': True  # Mark as split
                    })
            continue

        # 5b. No signal: group sentences into paragraphs of roughly 500 chars.
        # Record each sentence's exact span within step_content.
        sentence_spans = []  # (begin, end, sentence) offsets within step_content
        offset = 0
        for piece in re.split(r'(?<=[.!?])', step_content):
            sentence = piece.strip()
            if sentence:
                begin = offset + (len(piece) - len(piece.lstrip()))
                sentence_spans.append((begin, begin + len(sentence), sentence))
            offset += len(piece)

        target_length = len(step_content) / max(1, len(step_content) // 500)
        paragraph = []         # sentences of the paragraph being built
        paragraph_length = 0   # summed sentence lengths
        paragraph_start = 0    # offset of the paragraph's first sentence
        paragraph_end = 0      # offset just past the paragraph's last sentence
        for begin, end, sentence in sentence_spans:
            if paragraph_length + len(sentence) < target_length:
                if not paragraph:
                    paragraph_start = begin
                paragraph.append(sentence)
                paragraph_length += len(sentence)
            else:
                if paragraph:
                    steps_content.append({
                        'content': ' '.join(paragraph),
                        'start_pos': base_pos + paragraph_start,
                        'end_pos': base_pos + paragraph_end,
                        'split': True  # Mark as split
                    })
                paragraph = [sentence]
                paragraph_length = len(sentence)
                paragraph_start = begin
            paragraph_end = end

        if paragraph:
            steps_content.append({
                'content': ' '.join(paragraph),
                'start_pos': base_pos + paragraph_start,
                'end_pos': end_idx,  # Directly use the end position of the parent block
                'split': True  # Mark as split
            })

    for i, step in enumerate(steps_content):
        assert step['end_pos'] > step['start_pos'], \
            f"Step {i}'s end_pos <= start_pos: {step}"

    return steps_content
def extract_steps_content_v2(pred_output):
    """
    Split into segments on "\n\n", then merge every 5 consecutive segments into one step.

    Args:
    - pred_output: The original output text.

    Returns:
    - A list of dicts with keys 'content', 'start_pos' and 'end_pos'.
      'content' is the 5 segments re-joined with "\n\n", so it equals
      pred_output[start_pos:end_pos]. A chunk whose span is empty (a lone
      empty segment, e.g. after a trailing "\n\n" or for empty input) is
      omitted.
    """
    segments = pred_output.split('\n\n')

    # (start, end) of each segment in pred_output. Segments are separated by
    # the 2-character delimiter "\n\n".
    segment_positions = []
    current_pos = 0
    for segment in segments:
        segment_positions.append((current_pos, current_pos + len(segment)))
        current_pos += len(segment) + 2

    steps_content = []
    for i in range(0, len(segments), 5):
        chunk = segments[i:i + 5]
        start_pos = segment_positions[i][0]
        end_pos = segment_positions[i + len(chunk) - 1][1]
        if end_pos <= start_pos:
            # Zero-length chunk: it covers no text, and emitting it would
            # trip the end_pos > start_pos check below.
            continue
        steps_content.append({
            'content': '\n\n'.join(chunk),
            'start_pos': start_pos,
            'end_pos': end_pos
        })

    # Validate the results
    for i, step in enumerate(steps_content):
        assert step['end_pos'] > step['start_pos'], f"Step {i}'s end position should be greater than its start position"
        if i > 0:
            assert step['start_pos'] >= steps_content[i-1]['end_pos'], f"Step {i}'s start position should be after the previous step's end position"

    return steps_content
def extract_steps_content_v3(pred_output):
    """
    Split into segments on "\n\n"; each non-empty segment becomes one step.

    Args:
    - pred_output: The original output text.

    Returns:
    - A list of dicts with keys 'content', 'start_pos' and 'end_pos', where
      'content' == pred_output[start_pos:end_pos]. Empty segments (produced by
      runs of "\n\n\n\n", a leading/trailing "\n\n", or empty input) are
      skipped, since a zero-length step would fail the validation below.
    """
    steps_content = []
    current_pos = 0
    for segment in pred_output.split('\n\n'):
        start_pos = current_pos
        end_pos = start_pos + len(segment)
        # Advance past the segment and the 2-character delimiter "\n\n".
        current_pos = end_pos + 2
        if not segment:
            continue
        steps_content.append({
            'content': segment,
            'start_pos': start_pos,
            'end_pos': end_pos
        })

    # Validate the results
    for i, step in enumerate(steps_content):
        assert step['end_pos'] > step['start_pos'], f"Step {i}'s end position should be greater than its start position"
        if i > 0:
            assert step['start_pos'] >= steps_content[i-1]['end_pos'], f"Step {i}'s start position should be after the previous step's end position"
    print('steps_content: ',steps_content)
    return steps_content

def extract_steps_content_v4(pred_output):
    """
    Split ``pred_output`` into steps at signal_pool phrases, then by length.

    1. The text is cut immediately before every occurrence of a signal_pool
       phrase, so each phrase starts a new segment. Whitespace-only segments
       are dropped. If no phrase occurs, the whole text is one segment.
    2. A segment of at most 200 whitespace-separated words becomes one step.
       A longer segment is split on "\n\n" and its non-blank pieces are
       regrouped into steps of about 150 words. A single piece of more than
       200 words becomes a step of its own.

    Args:
    - pred_output: The original output text.

    Returns:
    - A list of dicts with keys 'content', 'start_pos' and 'end_pos'
      (positions index into pred_output). Empty input yields an empty list.
    """
    steps_content = []
    text_length = len(pred_output)

    # 1. Find the positions of all signal words, sorted by position
    signal_positions = []
    for signal in signal_pool.values():
        pos = pred_output.find(signal)
        while pos != -1:
            signal_positions.append((pos, signal))
            pos = pred_output.find(signal, pos + 1)
    signal_positions.sort(key=lambda x: x[0])

    # 2. Cut the text before each signal word. The signal word itself opens
    # the following segment.
    segments = []
    segment_positions = []
    last_pos = 0
    for pos, _signal in signal_positions:
        if pos > last_pos:
            segment = pred_output[last_pos:pos]
            if segment.strip():
                segments.append(segment)
                segment_positions.append((last_pos, pos))
        last_pos = pos

    # The text after the last cut
    if last_pos < text_length:
        segment = pred_output[last_pos:]
        if segment.strip():
            segments.append(segment)
            segment_positions.append((last_pos, text_length))

    # No signal words: the entire text is a single segment. Empty text
    # produces no segment, because a zero-length step would fail the
    # validation below.
    if not signal_positions and text_length:
        segments = [pred_output]
        segment_positions = [(0, text_length)]

    # 3. Emit short segments directly; regroup long ones by "\n\n" pieces
    for segment, (start_pos, end_pos) in zip(segments, segment_positions):
        if len(segment.split()) <= 200:
            steps_content.append({
                'content': segment,
                'start_pos': start_pos,
                'end_pos': end_pos
            })
            continue

        sub_start = start_pos       # start of the current "\n\n" piece
        current_step = []           # pieces of the step being built
        current_word_count = 0      # word count of the step being built
        current_step_start = sub_start

        for subseg in segment.split('\n\n'):
            if not subseg.strip():
                # Skip blank pieces, but keep positions in sync (+2 = "\n\n")
                sub_start += len(subseg) + 2
                continue

            words_in_subseg = len(subseg.split())

            # Flush the accumulated step if this piece would push it past
            # 150 words, or if this piece alone is very long.
            if (current_word_count > 0 and current_word_count + words_in_subseg > 150) or words_in_subseg > 200:
                if current_step:
                    steps_content.append({
                        'content': '\n\n'.join(current_step),
                        'start_pos': current_step_start,
                        'end_pos': sub_start - 2  # Subtract the length of the delimiter
                    })
                    current_step = []
                    current_word_count = 0

            # A new step starts at its first non-blank piece.
            if not current_step:
                current_step_start = sub_start
            current_step.append(subseg)
            current_word_count += words_in_subseg
            sub_end = sub_start + len(subseg)

            # A very long piece (over 200 words) is a step on its own.
            if words_in_subseg > 200:
                steps_content.append({
                    'content': '\n\n'.join(current_step),
                    'start_pos': current_step_start,
                    'end_pos': sub_end
                })
                current_step = []
                current_word_count = 0

            sub_start = sub_end + 2  # +2 is the length of the delimiter "\n\n"

        # Process the last remaining step
        if current_step:
            steps_content.append({
                'content': '\n\n'.join(current_step),
                'start_pos': current_step_start,
                'end_pos': sub_start - 2  # Subtract the length of the last delimiter
            })

    # Validate the results
    for i, step in enumerate(steps_content):
        assert step['end_pos'] > step['start_pos'], f"Step {i}'s end position should be greater than its start position"
        if i > 0:
            assert step['start_pos'] >= steps_content[i-1]['end_pos'], f"Step {i}'s start position should be after the previous step's end position"
    print('steps_content: ',steps_content)
    return steps_content
def build_input(steps, question):
    """Package extracted steps and the question for step scoring.

    Args:
        steps: Sequence of step dicts, each with 'content' and 'end_pos'.
        question: The question text. A fixed "no reference answer" note is
            appended to it.

    Returns:
        A tuple ``(contents, messages, end_positions)``:
        - the step contents;
        - a two-message chat (system + user);
        - each step's end position.
    """
    user_prompt = question + ('\n\n###\n\n'
                              'The reference answer is: There is no reference answer for this question.')
    system_message = {"role": "system", "content": "You are a helpful assistant."}
    user_message = {"role": "user", "content": user_prompt}
    messages = [system_message, user_message]

    contents = [item['content'] for item in steps]
    end_positions = [item['end_pos'] for item in steps]
    return contents, messages, end_positions
# ============================================================
#                          generate
# ============================================================
def setup_openai_client(port):
    """Return an OpenAI client for the OpenAI-compatible server at localhost:``port``.

    The API key is the placeholder "EMPTY", and requests go to
    ``http://localhost:{port}/v1``.
    """
    base_url = f"http://localhost:{port}/v1"
    return OpenAI(api_key="EMPTY", base_url=base_url)


# def get_token_number(model_path, text):
#     encoding = tiktoken.encoding_for_model(model_path)
#     num_token = len(encoding.encode(text))
#     return num_token


def streaming_generation_with_interruptions(
        messages,
        client,
        model_path,
        prompt_deepening,
        max_interrupts,
        interrupt_signals,
        sampling_params,
        normal_max_tokens,
        extend_max_tokens,
        prm_model,
        prm_tokenizer,
        stpes_divide,
        question,
        divide_step_method,
        threshold
    ) -> str:
    """
    Streaming generation function with interruption detection.
    Learned from: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html?ref=blog.mozilla.ai#chat-api

    Args:
        max_interrupts: Maximum number of interruptions.
            max_interrupts = 1: Can only be interrupted once.

    Returns:
        The final complete generated text.
    """

    round_idx = 1
    generation_full_history = []   # length = number of round
    is_finished = False
    should_interrupt = None
    ignored_signal_positions = {}  # Used to store ignored signals and their positions {signal: [positions]}
    missing_count = 0

    cached_response = ""  # Cache the response content from the previous round
    cached_steps = []     # Cache the steps from the previous round
    cached_scores = []    # Cache the scores from the previous round
    for m in messages:
        print(m)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    assert sampling_params['temperature'] == 0.6, "temperature should be 0.0"
    assert sampling_params['top_p'] == 0.95, "top_p should be 1.0"
    assert sampling_params['max_tokens'] == 32768, "max_tokens should be 32768"
    # Loop until generation finished
    while True:
        ###
        if is_finished:
            break

        ###
        disable_interrupt = False
        if round_idx > max_interrupts:
            disable_interrupt = True
            print('==> Reach max_interrupts, Disable_interrupt')

        if round_idx == 1:
            print(f'\n\n==> [Round {round_idx}]')
        else:
            print(f'\n\n==> [Round {round_idx}] Resume generation: is_interrupt = {should_interrupt}')
            if not should_interrupt:   # This means we are not here because of an interrupt, so it must be because of a finish.
                break

        print('\n--- input ---')
        for m in messages:
            print(m)
        print(messages[-1]['content'][:100])
        print('......')
        print(messages[-1]['content'][-200:])


        ### stream generation
        if messages[-1]['role'] == 'assistant':    # Continuation mode
            print('\n--- sampling ---')
            print('continue_final_message: True')

            # update max_token for this round
            num_token_previous = len(tokenizer.encode(messages[-1]["content"]))
            print('num_token_previous: ', num_token_previous)
            num_token_remain = normal_max_tokens - num_token_previous
            print('num_token_remain: ', num_token_remain)
            sampling_params['max_tokens'] = num_token_remain
            if disable_interrupt and extend_max_tokens and round_idx > 1:   # When interruption is not allowed, expand max_tokens in hopes of getting a complete output at the end.
                sampling_params['max_tokens'] += extend_max_tokens
                print(f'extend {extend_max_tokens} max_tokens')
            print('max_tokens this round: ', sampling_params['max_tokens'])

            #
            output_stream = client.chat.completions.create(
                model=model_path,
                messages=messages,
                stream=True,
                extra_body={  # vLLM specific extension parameters
                    "continue_final_message": True,
                    "add_generation_prompt": False
                },
                **sampling_params,
            )
        else:
            print('\n--- sampling ---')
            print('continue_final_message: False')

            output_stream = client.chat.completions.create(
                model=model_path,
                messages=messages,
                stream=True,
                **sampling_params,
            )

        ### generate until 'should_interrupt'
        print('\n--- output ---')
        output_this_round = ""

        for chunk in output_stream:
            ## (1) Get 'output'
            current_output = chunk.choices[0].delta.content
            if not current_output:
                continue
            # print
            print(current_output, end="", flush=True)   # end="": no newline   flush=True: force immediate buffer flush for instant display
            # update
            output_this_round += current_output
            # # debug
            # if round_idx != 1:
            #     output_this_round += ' Alternately,'

            # Additionally, check if '</think>' is in output_this_round. If so, do not interrupt.
            if '</think>' in output_this_round:
                disable_interrupt = True
                # print('==> Found </think>, interruption not allowed')
            ## (2) Whether allow interrupt?
            if disable_interrupt:
                continue

            ## (3) Whether should interrupt?
            found_signal_info = {}
            for signal in interrupt_signals:
                max_ignore_pos = ignored_signal_positions.get(signal, 0)
                if signal in output_this_round[max_ignore_pos:]:    # Note: Detect from output_this_round, not current_output, as the signal can be split between two adjacent chunks.
                    print(f"\n==> Interrupt signal detected: <{signal}>")
                    found_signal_info[signal] = len(output_this_round[max_ignore_pos:].split(signal)[0]) + max_ignore_pos + len(signal)   # Calculate the position of the signal in output_this_round
                    assert found_signal_info[signal] > max_ignore_pos, f"found_signal_info[signal]: {found_signal_info[signal]}, max_ignore_pos: {max_ignore_pos}"
            if len(found_signal_info) > 0:
                final_signal = min(found_signal_info, key=found_signal_info.get)
                final_signal_pos = found_signal_info[final_signal]
                print(f"final_signal: {final_signal}")
                should_interrupt = True
                interrupt_signal_this_round = final_signal

            else:
                should_interrupt = False
                continue

            ## (4) Process for interrupt
            if should_interrupt:
                output_inherited_by_next_round = output_this_round[:final_signal_pos-len(final_signal)].strip(' \n')    # Take the content before the first signal and pass it to the next round
                # Here we need to check if the idea being replaced is a good one. If it is, we should interrupt to dive deeper. If not, we don't need to interrupt.
                score,step_content = check_score(messages,output_inherited_by_next_round,prm_model,prm_tokenizer,stpes_divide,question,divide_step_method,cached_response, cached_steps, cached_scores)
                if score == False:
                    print('failed to extract steps')
                    should_interrupt = False
                    # Set the position of the ignored signal
                    if final_signal not in ignored_signal_positions:
                        ignored_signal_positions[final_signal] = found_signal_info[final_signal]
                    else:
                        ignored_signal_positions[final_signal] = max(ignored_signal_positions[final_signal], found_signal_info[final_signal])
                    print(f'ignored_signal_positions: {ignored_signal_positions}')
                    continue
                else:
                    if messages[-1]['role'] == 'user':
                        cached_response = output_inherited_by_next_round
                    elif messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'user':
                        cached_response = messages[-1]['content'] + output_inherited_by_next_round
                    cached_steps = step_content
                    cached_scores = score
                if threshold == 'v0':
                    threshold_value = 0.69
                elif threshold == 'v1':
                    threshold_value = 0.70
                elif threshold == 'v2':
                    threshold_value = 0.71
                elif threshold == 'v3':
                    threshold_value = sum(score) / len(score)   # Take the average
                elif threshold == 'v4':
                    # EMA exponential smoothing
                    alpha = 0.8
                    threshold_value = score[0]
                    if len(score) > 1:
                        for i in range(1, len(score)):
                            threshold_value = alpha * score[i] + (1 - alpha) * threshold_value
                else:
                    raise ValueError(f"Invalid threshold value: {threshold}")
                # threshold = 0.71   # Take the average
                # threshold = sum(score) / len(score)   # Take the average
                print(f'score: {score}, threshold: {threshold_value}')
                if score[-1] >= threshold_value:
                    print('Score above the threshold. Time to Dive deeper !!')
                else:
                    print('Score below threshold. Ignoring this interrupt signal.')
                    missing_count += 1
                    if missing_count > 10:
                        disable_interrupt = True
                        print('==> Reach max_missing_count, Disable_interrupt')
                    should_interrupt = False
                    # Set the position of the ignored signal
                    if final_signal not in ignored_signal_positions:
                        ignored_signal_positions[final_signal] = found_signal_info[final_signal]
                    else:
                        ignored_signal_positions[final_signal] = max(ignored_signal_positions[final_signal], found_signal_info[final_signal])
                    print(f'ignored_signal_positions: {ignored_signal_positions}')
                    # Continue to the next chunk
                    continue
                current_output_token_num = len(tokenizer.encode(output_inherited_by_next_round))
                print('output_token_number this round: ', current_output_token_num)
                #
                if messages[-1]['role'] == 'user':
                    full_messages_this_round = messages + [{"role": "assistant", "content": output_inherited_by_next_round}]
                elif messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'user':
                    full_messages_this_round = messages[:-1] + [{"role": "assistant", "content": output_inherited_by_next_round}]
                else:
                    raise ValueError
                # update generation_history
                history_this_round = {
                    'round_id': round_idx,
                    'interrupt_signal':  interrupt_signal_this_round,
                    'this_round_output': output_this_round,
                    'this_round_output_token_number': current_output_token_num,
                    'this_round_output_inherited_by_next_round': output_inherited_by_next_round,
                    'this_round_sampling_params': sampling_params,
                    'this_round_input_messages': messages,
                    'this_round_full_messages': full_messages_this_round,
                    'this_round_score': score,
                    'this_round_step_content': step_content,
                }
                generation_full_history.append(history_this_round)

                # update message for next round
                new_input = output_inherited_by_next_round + prompt_deepening    # newly added input for next_round

                if messages[-1]['role'] == 'user':    # This means there was no previous interruption, so a new assistant message needs to be added.
                    assis_content = new_input
                    assis_message = {"role": "assistant", "content": assis_content}
                    messages.append(assis_message)
                elif messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'user':    # Then the existing assistant message needs to be updated.
                    assis_content = messages[-1]['content'] + new_input
                    assis_message = {"role": "assistant", "content": assis_content}
                    # messages[-1]['content'] = assis_content
                    messages = messages[:-1] + [assis_message]
                else:
                    raise ValueError

                # Exit: 'for chunk in output_stream:'
                output_stream.response.close()    # Explicitly close
                break


        ## Whether Generation is Finished
        # Note: If should_interrupt is true, it means this round's output still contains a change of thought, so the finish is voided and generation needs to continue. Therefore, check should_interrupt before checking finish.
        finish_reason = getattr(chunk.choices[0], "finish_reason", None)
        if finish_reason is not None:
            is_finished = True
            print("Finish reason:", finish_reason)

        round_idx += 1
        # here we will give each interrupt signal a beta to avoid dead loop for always change the thinking
        beta = 500
        ignored_signal_positions = {signal: beta for _, signal in signal_pool.items()}


    ### Finished
    print(f"\n==> Finished !!!")

    # update generation_history for the last round
    current_output_token_num = len(tokenizer.encode(output_this_round))
    if messages[-1]['role'] == 'user':
        full_messages_this_round = messages + [{"role": "assistant", "content": output_this_round}]
    elif messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'user':
        full_messages_this_round = messages[:-1] + [{"role": "assistant", "content": output_this_round}]
    else:
        raise ValueError

    history_this_round = {
        'mode': 'stream',
        'round_id': round_idx - 1,
        'finish_reason': finish_reason,
        'this_round_output': output_this_round,
        'this_round_output_token_number': current_output_token_num,
        'this_round_sampling_params': sampling_params,
        'this_round_input_messages': messages,
        'this_round_full_messages': full_messages_this_round,
        'this_round_score': None,
        'this_round_step_content': None,
    }
    generation_full_history.append(history_this_round)

    # collect the final version of full_response
    if messages[-1]['role'] == 'assistant':
        full_output = messages[-1]['content'] + output_this_round
    else:   # for the case without interruption
        full_output = output_this_round
    full_output = full_output.lstrip(' \n')
    if not full_output.startswith('<think>'):
        full_output = '<think>\n' + full_output
    #
    full_output_token_num = len(tokenizer.encode(full_output))
    print('token_number of full_output: ', full_output_token_num)

    #
    generation_info = {
        'history': generation_full_history,
        'num_round': len(generation_full_history),
        'full_output': full_output,
        'full_output_token_number': full_output_token_num,
        'finish_reason': finish_reason,
    }

    return generation_info



# ============================================================
#                            main
# ============================================================
def main():
    """Run interrupted streaming generation over every not-yet-processed item.

    Steps:
      1. Parse CLI arguments and compute the output paths.
      2. If the ``.jsonl`` output already exists, reload it so the run resumes
         instead of starting over.
      3. Set up the generation client, the PRM and the step divider.
      4. Load the requested slice of the dataset. When ``--only_wrong`` is
         set, it is additionally passed through ``filter_right``. Items
         already present in the reloaded results are then dropped via
         ``filter_exist``.
      5. For each remaining item, build the chat messages and call
         ``streaming_generation_with_interruptions``. The returned info is
         attached to the item, and both the ``.jsonl`` and ``.json`` outputs
         are rewritten after every item.
    """
    args = _parse_args()
    save_path_jsonl, save_path_json = get_save_path(args)

    # --- resume: pick up the results a previous run already wrote ---
    results = []
    if os.path.exists(save_path_jsonl):
        print(f"\n==> Found existing results file: {save_path_jsonl}")
        print("==> Loading existing results and continuing processing...")
        results = load_jsonl(save_path_jsonl)
        print(f"==> Loaded {len(results)} processed data entries")

    # --- services: generation client, PRM, and step-divider ---
    client = setup_openai_client(args.port)
    # setup_prm also returns sampling params, which this driver does not use.
    prm_model, prm_tokenizer, _prm_sampling_params = setup_prm(args)
    step_divider = setup_step_divide()

    # --- data selection ---
    pending = load_data(args.data_path, args.start_id, args.end_id)
    print(f"==> Loading data from {args.data_path}, total {len(pending)} entries")
    if args.only_wrong:
        print("==> Processing only the incorrect problems")
        pending = filter_right(pending, args.gt_answer_key)
        print(f"==> After filtering correct problems, {len(pending)} entries remain")
    pending = filter_exist(pending, results, args.question_id)
    print(f"==> After filtering, {len(pending)} unprocessed entries remain")

    if len(pending) == 0:
        print("==> All data has been processed, no need to continue")
        return

    # --- prompt pieces shared by every item ---
    question_prefix = get_question_prefix(args.question_prefix_version)
    question_suffix = get_question_suffix(args.question_suffix_version)
    system_prompt = get_system_prompt(args.system_prompt_version)
    deepen_prompt = get_deepen_prompt(args.deepen_prompt_version)

    print('\n---------------------- Begin Generation ----------------------')
    interrupt_signals = get_interrupt_signals(args.interrupt_signals_version)
    sampling_params = get_sampling_params(args)
    # Dataset type is decided once from the path; GPQA items carry a choice list.
    is_gpqa = 'gpqa' in args.data_path

    for item in tqdm(pending):
        question = item[args.question_key]
        if is_gpqa:
            # Multiple-choice format: question, then rendered options.
            # Note: question_prefix is not applied on this branch.
            options_text = choice_dict_to_str(choice_list_to_dict(item['choice_list']))
            user_content = 'Question: ' + question + '\nOptions:\n' + options_text + question_suffix
        else:
            user_content = question_prefix + question + question_suffix

        # The system message is only included when a non-empty system prompt is configured.
        messages = [{"role": "system", "content": system_prompt}] if system_prompt != '' else []
        messages.append({"role": "user", "content": user_content})

        # Each item gets its own copy of the sampling params, so changes made
        # during generation do not leak into the next item.
        generation_info = streaming_generation_with_interruptions(
            messages, client, args.model_path, deepen_prompt,
            args.max_interrupts, interrupt_signals, sampling_params.copy(),
            args.max_tokens, args.extend_max_tokens,
            prm_model, prm_tokenizer, step_divider,
            question, args.divide_step_method, args.threshold,
        )

        # Record the run configuration alongside the generation trace.
        generation_info.update(
            max_interrupts=args.max_interrupts,
            interrupt_signals=interrupt_signals,
            deepen_prompt=deepen_prompt,
        )
        item['generation_info'] = generation_info
        results.append(item)

        # Rewrite both outputs after every item; the .jsonl file is what the
        # resume step above reloads.
        save_jsonl(results, save_path_jsonl)
        save_json(results, save_path_json)

    print('done!')


# Only start a (potentially long) generation run when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()